{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Limit the CUDA devices visible to this process; set here, before the cell that imports TensorFlow.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from rotary_embedding_tensorflow import apply_rotary_emb, RotaryEmbedding\n",
    "from fast_transformer import FastTransformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from malaya.text.bpe import WordPieceTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Case-preserving WordPiece tokenizer built from 'BERT.wordpiece'.\n",
    "# NOTE(review): presumably the vocabulary file of the pretrained model - confirm.\n",
    "tokenizer = WordPieceTokenizer('BERT.wordpiece', do_lower_case = False)\n",
    "# tokenizer.tokenize('halo nama sayacomel')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# The pickle holds four objects, unpacked as train/test inputs (X) and labels (Y).\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only load a file you trust.\n",
    "with open('entities-fastformer.pkl', 'rb') as fopen:\n",
    "    train_X, train_Y, test_X, test_Y = pickle.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training hyperparameters. batch_size drives the minibatch loops further down.\n",
    "# NOTE(review): in the visible cells the training loop iterates range(1) and the model uses\n",
    "# tf.train.AdamOptimizer with a fixed learning rate, so epoch, num_train_steps and\n",
    "# num_warmup_steps do not appear to be read anywhere - confirm whether a warmup schedule was intended.\n",
    "epoch = 2\n",
    "batch_size = 32\n",
    "warmup_proportion = 0.1\n",
    "num_train_steps = int(len(train_X) / batch_size * epoch)\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/malaya/malaya/pretrained-model/fastformer/optimization.py:87: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_initializer(initializer_range=0.02):\n",
    "    \"\"\"Return a truncated-normal initializer whose stddev is `initializer_range`.\"\"\"\n",
    "    truncated_normal = tf.truncated_normal_initializer(stddev=initializer_range)\n",
    "    return truncated_normal\n",
    "\n",
    "class Model:\n",
    "    \"\"\"FastTransformer encoder followed by a dense projection and a CRF, for per-token tagging.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output,\n",
    "        learning_rate = 2e-5,\n",
    "        training = True,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Create the placeholders, encoder, CRF loss, Adam training op and decoding / accuracy ops.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        dimension_output: int\n",
    "            Number of units of the dense projection, i.e. number of tags scored per token.\n",
    "        learning_rate: float\n",
    "            Learning rate given to tf.train.AdamOptimizer.\n",
    "        training: bool\n",
    "            Not read anywhere in this constructor.\n",
    "        \"\"\"\n",
    "        # Token ids with dynamic batch and time dimensions; id 0 is treated as padding below.\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        # True wherever the token id is non-zero.\n",
    "        mask = tf.math.not_equal(self.X, 0)\n",
    "        # not_equal already yields booleans, so this cast does not change the values.\n",
    "        mask = tf.cast(mask, tf.bool)\n",
    "        # Gold tag ids. NOTE(review): presumably padded to the same shape as X - confirm against the data prep.\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.maxlen = tf.shape(self.X)[1]\n",
    "        # Per-row count of non-zero ids, passed to the CRF ops as the sequence length.\n",
    "        self.lengths = tf.count_nonzero(self.X, 1)\n",
    "        \n",
    "        self.model = FastTransformer(\n",
    "            num_tokens = 32000,\n",
    "            dim = 768,\n",
    "            depth = 12,\n",
    "            heads = 12,\n",
    "            max_seq_len = 2048,\n",
    "            absolute_pos_emb = True,\n",
    "            mask = mask\n",
    "        )\n",
    "        # Only the first element of the encoder's return value is used.\n",
    "        # NOTE(review): presumably the per-token hidden states - confirm against fast_transformer.\n",
    "        self.logits = self.model(self.X)[0]\n",
    "        # Per-tag scores; this local `logits` (not self.logits) is what the CRF ops consume.\n",
    "        logits = tf.layers.dense(self.logits, dimension_output,\n",
    "                                         kernel_initializer=create_initializer())\n",
    "        \n",
    "        y_t = self.Y\n",
    "        # CRF log-likelihood of the gold tags; the returned transition matrix is reused for decoding.\n",
    "        log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(\n",
    "            logits, y_t, self.lengths\n",
    "        )\n",
    "        # Loss is the batch mean of the negative log-likelihood.\n",
    "        self.cost = tf.reduce_mean(-log_likelihood)\n",
    "        self.optimizer = tf.train.AdamOptimizer(\n",
    "            learning_rate = learning_rate\n",
    "        ).minimize(self.cost)\n",
    "        # Rebuild the mask from the lengths: True for the first `lengths` positions of each row.\n",
    "        mask = tf.sequence_mask(self.lengths, maxlen = self.maxlen)\n",
    "        # Decode a tag sequence for each row using the learned transitions.\n",
    "        self.tags_seq, tags_score = tf.contrib.crf.crf_decode(\n",
    "            logits, transition_params, self.lengths\n",
    "        )\n",
    "        # Give the decoded tags the fixed op name 'logits' so they can be looked up by name in the graph.\n",
    "        self.tags_seq = tf.identity(self.tags_seq, name = 'logits')\n",
    "\n",
    "        # Y is already int32, so this cast leaves the values unchanged.\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        # Compare predictions and gold tags only at positions inside each sequence length.\n",
    "        self.prediction = tf.boolean_mask(self.tags_seq, mask)\n",
    "        mask_label = tf.boolean_mask(y_t, mask)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        # correct_index is not used afterwards.\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open('dictionary-entities.json') as fopen:\n",
    "    dictionary = json.load(fopen)\n",
    "\n",
    "tag2idx = dictionary['tag2idx']\n",
    "idx2tag = {int(k): v for k, v in dictionary['idx2tag'].items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/initializers.py:119: calling RandomUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n",
      "WARNING:tensorflow:From /home/husein/malaya/malaya/pretrained-model/fastformer/fast_transformer/fast_attention.py:87: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "Tensor(\"fast_transformer/pre_norm/fast_attention/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm/fast_attention/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_2/fast_attention_1/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_2/fast_attention_1/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_4/fast_attention_2/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_4/fast_attention_2/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_6/fast_attention_3/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_6/fast_attention_3/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_8/fast_attention_4/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_8/fast_attention_4/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_10/fast_attention_5/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_10/fast_attention_5/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_12/fast_attention_6/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_12/fast_attention_6/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_14/fast_attention_7/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_14/fast_attention_7/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_16/fast_attention_8/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_16/fast_attention_8/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_18/fast_attention_9/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_18/fast_attention_9/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_20/fast_attention_10/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_20/fast_attention_10/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_22/fast_attention_11/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "Tensor(\"fast_transformer/pre_norm_22/fast_attention_11/Select:0\", shape=(?, 12, ?), dtype=float32)\n",
      "WARNING:tensorflow:From <ipython-input-8-361fb0fa976c>:29: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/contrib/crf/python/ops/crf.py:213: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n"
     ]
    }
   ],
   "source": [
    "# One output unit per tag in the dictionary.\n",
    "dimension_output = len(tag2idx)\n",
    "learning_rate = 2e-5\n",
    "\n",
    "# Clear the default graph first so re-running this cell does not stack a second model on the first.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate\n",
    ")\n",
    "\n",
    "# Give every variable an initial value; a later cell restores checkpoint weights over the matching ones.\n",
    "sess.run(tf.global_variables_initializer())\n",
    "var_lists = tf.trainable_variables()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import re\n",
    "\n",
    "def get_assignment_map_from_checkpoint(tvars, init_checkpoint):\n",
    "    \"\"\"Pair graph variables with the checkpoint entries that share their name.\n",
    "\n",
    "    Only names present both in `tvars` (with any ':<digits>' suffix removed) and in\n",
    "    the checkpoint are kept, i.e. the intersection of the two sets.\n",
    "\n",
    "    Returns a tuple (assignment_map, initialized_variable_names):\n",
    "    assignment_map is an OrderedDict from checkpoint name to graph variable, and\n",
    "    initialized_variable_names is a dict holding each matched name, with and\n",
    "    without a ':0' suffix, mapped to 1.\n",
    "    \"\"\"\n",
    "    graph_variables = collections.OrderedDict()\n",
    "    for variable in tvars:\n",
    "        matched = re.match('^(.*):\\\\d+$', variable.name)\n",
    "        key = variable.name if matched is None else matched.group(1)\n",
    "        graph_variables[key] = variable\n",
    "\n",
    "    assignment_map = collections.OrderedDict()\n",
    "    initialized_variable_names = {}\n",
    "    for entry in tf.train.list_variables(init_checkpoint):\n",
    "        # Each entry is indexed as (name, ...); only the name is needed here.\n",
    "        checkpoint_name = entry[0]\n",
    "        if checkpoint_name in graph_variables:\n",
    "            assignment_map[checkpoint_name] = graph_variables[checkpoint_name]\n",
    "            initialized_variable_names[checkpoint_name] = 1\n",
    "            initialized_variable_names[checkpoint_name + ':0'] = 1\n",
    "\n",
    "    return (assignment_map, initialized_variable_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Match the graph's trainable variables against same-named entries in the 'fastformer-base' checkpoint.\n",
    "tvars = tf.trainable_variables()\n",
    "checkpoint = 'fastformer-base/model.ckpt-500000'\n",
    "assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars, \n",
    "                                                                                checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from fastformer-base/model.ckpt-500000\n"
     ]
    }
   ],
   "source": [
    "# Restore only the variables listed in assignment_map; every other variable keeps its initialised value.\n",
    "saver = tf.train.Saver(var_list = assignment_map)\n",
    "saver.restore(sess, checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short alias for the Keras padding helper used by the batching loops below.\n",
    "pad_sequences = tf.keras.preprocessing.sequence.pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample paragraph that is tagged before and after training as a visual sanity check.\n",
    "string = 'KUALA LUMPUR: Sempena sambutan Aidilfitri minggu depan, Perdana Menteri Tun Dr Mahathir Mohamad dan Menteri Pengangkutan Anthony Loke Siew Fook menitipkan pesanan khas kepada orang ramai yang mahu pulang ke kampung halaman masing-masing. Dalam video pendek terbitan Jabatan Keselamatan Jalan Raya (JKJR) itu, Dr Mahathir menasihati mereka supaya berhenti berehat dan tidur sebentar  sekiranya mengantuk ketika memandu.'\n",
    "\n",
    "import re\n",
    "\n",
    "def entities_textcleaning(string, lowering = False):\n",
    "    \"\"\"\n",
    "    Clean raw text for entity recognition, POS recognition and dependency parsing.\n",
    "\n",
    "    Every run of characters other than ASCII letters, digits, '-', '/', '(', ')'\n",
    "    and space is replaced by one space, repeated spaces are collapsed, and the\n",
    "    text is split on whitespace.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    string: str\n",
    "        Text to clean.\n",
    "    lowering: bool, optional (default=False)\n",
    "        If True, lowercase the cleaned text before building the normalised words.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (original_words, normalised_words): tuple of two lists of equal length\n",
    "        original_words holds the cleaned words as written. normalised_words holds\n",
    "        the same words with fully upper-case words title-cased (after lowercasing\n",
    "        when lowering is True).\n",
    "    \"\"\"\n",
    "    # Raw-string pattern: the backslash escapes are meant for the regex engine, not for Python.\n",
    "    string = re.sub(r'[^A-Za-z0-9\\-\\/() ]+', ' ', string)\n",
    "    string = re.sub(r'[ ]+', ' ', string).strip()\n",
    "    original_words = string.split()\n",
    "    if lowering:\n",
    "        string = string.lower()\n",
    "    # split() never yields empty strings, so no length filter is needed.\n",
    "    normalised_words = [\n",
    "        word.title() if word.isupper() else word for word in string.split()\n",
    "    ]\n",
    "    return original_words, normalised_words\n",
    "\n",
    "def parse_X(left):\n",
    "    \"\"\"Tokenise each word in `left` and wrap the pieces in '[CLS]' ... '[SEP]'.\n",
    "\n",
    "    Uses the notebook-level `tokenizer`. Returns a tuple of\n",
    "    (token ids, token strings, list of ones with one entry per token).\n",
    "    \"\"\"\n",
    "    pieces = [piece for word in left for piece in tokenizer.tokenize(word)]\n",
    "    bert_tokens = ['[CLS]'] + pieces + ['[SEP]']\n",
    "    token_ids = tokenizer.convert_tokens_to_ids(bert_tokens)\n",
    "    return token_ids, bert_tokens, [1] * len(bert_tokens)\n",
    "\n",
    "# Keep the normalised words (index 1 of the result) and convert them to ids / tokens for the sample text.\n",
    "sequence = entities_textcleaning(string)[1]\n",
    "parsed_sequence, bert_sequence, input_mask = parse_X(sequence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag the sample text as a batch of one and keep its single row.\n",
    "# This cell comes before the training loop, so it shows the model prior to fine-tuning.\n",
    "predicted = sess.run(model.tags_seq,\n",
    "                feed_dict = {\n",
    "                    model.X: [parsed_sequence],\n",
    "                },\n",
    "        )[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_wordpiece_tokens_tagging(x, y):\n",
    "    \"\"\"\n",
    "    Merge WordPiece continuation tokens ('##...') back into whole words, one label per word.\n",
    "\n",
    "    A merged word keeps the label of its first piece. '[CLS]', '[SEP]' and '[PAD]'\n",
    "    are dropped from the result. A continuation piece at the very start, with no\n",
    "    previous token to attach to, is kept as its own token, and a sequence may end\n",
    "    on a continuation piece.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x: list of str\n",
    "        WordPiece tokens.\n",
    "    y: list\n",
    "        One label per token in x.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (words, labels): tuple of two lists of equal length\n",
    "    \"\"\"\n",
    "    rejected = ['[CLS]', '[SEP]', '[PAD]']\n",
    "    new_paired_tokens = []\n",
    "\n",
    "    for position, current_token in enumerate(x):\n",
    "        current_label = y[position]\n",
    "        if current_token.startswith('##') and new_paired_tokens:\n",
    "            # Continuation piece: glue it onto the previous word and keep that word's label.\n",
    "            # Only the leading '##' marker is stripped.\n",
    "            previous_token, previous_label = new_paired_tokens[-1]\n",
    "            new_paired_tokens[-1] = (previous_token + current_token[2:], previous_label)\n",
    "        else:\n",
    "            # Regular token, or a continuation piece with nothing before it to attach to.\n",
    "            new_paired_tokens.append((current_token, current_label))\n",
    "\n",
    "    kept = [pair for pair in new_paired_tokens if pair[0] not in rejected]\n",
    "    words = [pair[0] for pair in kept]\n",
    "    labels = [pair[1] for pair in kept]\n",
    "    return words, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Kuala', 'event'),\n",
       " ('Lumpur', 'X'),\n",
       " ('Sempena', 'X'),\n",
       " ('sambutan', 'organization'),\n",
       " ('Aidilfitri', 'X'),\n",
       " ('minggu', 'person'),\n",
       " ('depan', 'X'),\n",
       " ('Perdana', 'X'),\n",
       " ('Menteri', 'X'),\n",
       " ('Tun', 'organization'),\n",
       " ('Dr', 'person'),\n",
       " ('Mahathir', 'X'),\n",
       " ('Mohamad', 'location'),\n",
       " ('dan', 'X'),\n",
       " ('Menteri', 'X'),\n",
       " ('Pengangkutan', 'organization'),\n",
       " ('Anthony', 'X'),\n",
       " ('Loke', 'X'),\n",
       " ('Siew', 'location'),\n",
       " ('Fook', 'X'),\n",
       " ('menitipkan', 'X'),\n",
       " ('pesanan', 'organization'),\n",
       " ('khas', 'X'),\n",
       " ('kepada', 'person'),\n",
       " ('orang', 'X'),\n",
       " ('ramai', 'X'),\n",
       " ('yang', 'X'),\n",
       " ('mahu', 'organization'),\n",
       " ('pulang', 'X'),\n",
       " ('ke', 'X'),\n",
       " ('kampung', 'event'),\n",
       " ('halaman', 'X'),\n",
       " ('masing', 'event'),\n",
       " ('-', 'person'),\n",
       " ('masing', 'X'),\n",
       " ('Dalam', 'event'),\n",
       " ('video', 'person'),\n",
       " ('pendek', 'X'),\n",
       " ('terbitan', 'X'),\n",
       " ('Jabatan', 'X'),\n",
       " ('Keselamatan', 'event'),\n",
       " ('Jalan', 'organization'),\n",
       " ('Raya', 'X'),\n",
       " ('(', 'event'),\n",
       " ('Jkjr', 'X'),\n",
       " (')', 'event'),\n",
       " ('itu', 'X'),\n",
       " ('Dr', 'event'),\n",
       " ('Mahathir', 'person'),\n",
       " ('menasihati', 'X'),\n",
       " ('mereka', 'event'),\n",
       " ('supaya', 'X'),\n",
       " ('berhenti', 'organization'),\n",
       " ('berehat', 'X'),\n",
       " ('dan', 'X'),\n",
       " ('tidur', 'X'),\n",
       " ('sebentar', 'event'),\n",
       " ('sekiranya', 'X'),\n",
       " ('mengantuk', 'X'),\n",
       " ('ketika', 'event'),\n",
       " ('memandu', 'X')]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Map predicted tag ids to tag names, merge word pieces back into words and show (word, tag) pairs.\n",
    "merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])\n",
    "list(zip(merged[0], merged[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 32380/32380 [3:39:17<00:00,  2.46it/s, accuracy=1, cost=0.0738]     \n",
      "test minibatch loop: 100%|██████████| 4688/4688 [17:29<00:00,  4.47it/s, accuracy=0.979, cost=5.69]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 14207.170645713806\n",
      "epoch: 0, training loss: 1.841172, training acc: 0.991068, valid loss: 6.252518, valid acc: 0.980113\n",
      "\n",
      "[('Kuala', 'location'), ('Lumpur', 'location'), ('Sempena', 'location'), ('sambutan', 'OTHER'), ('Aidilfitri', 'event'), ('minggu', 'time'), ('depan', 'time'), ('Perdana', 'person'), ('Menteri', 'person'), ('Tun', 'person'), ('Dr', 'person'), ('Mahathir', 'person'), ('Mohamad', 'person'), ('dan', 'OTHER'), ('Menteri', 'person'), ('Pengangkutan', 'person'), ('Anthony', 'person'), ('Loke', 'person'), ('Siew', 'person'), ('Fook', 'person'), ('menitipkan', 'X'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing', 'OTHER'), ('-', 'X'), ('masing', 'X'), ('Dalam', 'OTHER'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'OTHER'), ('Jabatan', 'organization'), ('Keselamatan', 'organization'), ('Jalan', 'location'), ('Raya', 'location'), ('(', 'location'), ('Jkjr', 'X'), (')', 'X'), ('itu', 'OTHER'), ('Dr', 'person'), ('Mahathir', 'person'), ('menasihati', 'OTHER'), ('mereka', 'OTHER'), ('supaya', 'OTHER'), ('berhenti', 'OTHER'), ('berehat', 'OTHER'), ('dan', 'OTHER'), ('tidur', 'OTHER'), ('sebentar', 'OTHER'), ('sekiranya', 'OTHER'), ('mengantuk', 'OTHER'), ('ketika', 'OTHER'), ('memandu', 'OTHER')]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "# NOTE(review): this runs a single pass (range(1)) although the hyperparameter cell sets epoch = 2 -\n",
    "# confirm which of the two is intended.\n",
    "for e in range(1):\n",
    "    lasttime = time.time()\n",
    "    train_acc, train_loss, test_acc, test_loss = [], [], [], []\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        # Slice one batch and right-pad inputs and labels to the longest row in the batch.\n",
    "        index = min(i + batch_size, len(train_X))\n",
    "        batch_x = train_X[i : index]\n",
    "        batch_y = train_Y[i : index]\n",
    "        batch_x = pad_sequences(batch_x, padding='post')\n",
    "        batch_y = pad_sequences(batch_y, padding='post')\n",
    "        \n",
    "        # Fetching model.optimizer applies one Adam update for this batch.\n",
    "        acc, cost, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {\n",
    "                model.X: batch_x,\n",
    "                model.Y: batch_y,\n",
    "            },\n",
    "        )\n",
    "        # Stop immediately if the loss has become NaN.\n",
    "        assert not np.isnan(cost)\n",
    "        train_loss.append(cost)\n",
    "        train_acc.append(acc)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "    \n",
    "    pbar = tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'test minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_X))\n",
    "        batch_x = test_X[i : index]\n",
    "        batch_y = test_Y[i : index]\n",
    "        batch_x = pad_sequences(batch_x, padding='post')\n",
    "        batch_y = pad_sequences(batch_y, padding='post')\n",
    "        \n",
    "        # Same fetches as in training minus model.optimizer, so no weights are updated here.\n",
    "        acc, cost = sess.run(\n",
    "            [model.accuracy, model.cost],\n",
    "            feed_dict = {\n",
    "                model.X: batch_x,\n",
    "                model.Y: batch_y,\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        test_loss.append(cost)\n",
    "        test_acc.append(acc)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "    \n",
    "    # Unweighted means over batches: a smaller final batch counts as much as a full one.\n",
    "    train_loss = np.mean(train_loss)\n",
    "    train_acc = np.mean(train_acc)\n",
    "    test_loss = np.mean(test_loss)\n",
    "    test_acc = np.mean(test_acc)\n",
    "\n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (e, train_loss, train_acc, test_loss, test_acc)\n",
    "    )\n",
    "    # Tag the sample text again to eyeball the effect of this pass.\n",
    "    predicted = sess.run(model.tags_seq,\n",
    "                feed_dict = {\n",
    "                    model.X: [parsed_sequence],\n",
    "                },\n",
    "        )[0]\n",
    "    merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])\n",
    "    print(list(zip(merged[0], merged[1])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('Kuala', 'location'), ('Lumpur', 'location'), ('Sempena', 'location'), ('sambutan', 'OTHER'), ('Aidilfitri', 'event'), ('minggu', 'time'), ('depan', 'time'), ('Perdana', 'person'), ('Menteri', 'person'), ('Tun', 'person'), ('Dr', 'person'), ('Mahathir', 'person'), ('Mohamad', 'person'), ('dan', 'OTHER'), ('Menteri', 'person'), ('Pengangkutan', 'person'), ('Anthony', 'person'), ('Loke', 'person'), ('Siew', 'person'), ('Fook', 'person'), ('menitipkan', 'X'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing', 'OTHER'), ('-', 'X'), ('masing', 'X'), ('Dalam', 'OTHER'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'OTHER'), ('Jabatan', 'organization'), ('Keselamatan', 'organization'), ('Jalan', 'location'), ('Raya', 'location'), ('(', 'location'), ('Jkjr', 'X'), (')', 'X'), ('itu', 'OTHER'), ('Dr', 'person'), ('Mahathir', 'person'), ('menasihati', 'OTHER'), ('mereka', 'OTHER'), ('supaya', 'OTHER'), ('berhenti', 'OTHER'), ('berehat', 'OTHER'), ('dan', 'OTHER'), ('tidur', 'OTHER'), ('sebentar', 'OTHER'), ('sekiranya', 'OTHER'), ('mengantuk', 'OTHER'), ('ketika', 'OTHER'), ('memandu', 'OTHER')]\n"
     ]
    }
   ],
   "source": [
    "# Tag the sample text once more with the trained weights and print the (word, tag) pairs.\n",
    "predicted = sess.run(model.tags_seq,\n",
    "            feed_dict = {\n",
    "                model.X: [parsed_sequence],\n",
    "            },\n",
    "    )[0]\n",
    "merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])\n",
    "print(list(zip(merged[0], merged[1])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'fastformer-base-entities/model.ckpt'"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Save only the trainable variables to 'fastformer-base-entities'; freeze_graph() below reads this checkpoint back.\n",
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.save(sess, 'fastformer-base-entities/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pred2label(pred):\n",
    "    \"\"\"Translate a nested sequence of tag ids into tag names using the notebook-level `idx2tag`.\"\"\"\n",
    "    return [[idx2tag[tag_id] for tag_id in row] for row in pred]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "validation minibatch loop: 100%|██████████| 4688/4688 [18:12<00:00,  4.29it/s]\n"
     ]
    }
   ],
   "source": [
    "# Collect gold and predicted tag names for every test row.\n",
    "# NOTE(review): rows are kept at their padded length, so padded positions (id 0) end up in\n",
    "# real_Y / predict_Y and therefore in any metric computed from them - confirm this is intended.\n",
    "real_Y, predict_Y = [], []\n",
    "\n",
    "pbar = tqdm(\n",
    "    range(0, len(test_X), batch_size), desc = 'validation minibatch loop'\n",
    ")\n",
    "for i in pbar:\n",
    "    index = min(i + batch_size, len(test_X))\n",
    "    batch_x = test_X[i : index]\n",
    "    batch_y = test_Y[i : index]\n",
    "    batch_x = pad_sequences(batch_x, padding='post')\n",
    "    batch_y = pad_sequences(batch_y, padding='post')\n",
    "    # Decoding needs only the inputs; the labels are not fed here.\n",
    "    predicted = pred2label(sess.run(model.tags_seq,\n",
    "            feed_dict = {\n",
    "                model.X: batch_x,\n",
    "            },\n",
    "    ))\n",
    "    real = pred2label(batch_y)\n",
    "    predict_Y.extend(predicted)\n",
    "    real_Y.extend(real)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten the per-row tag lists into one flat list each, as expected by classification_report.\n",
    "temp_real_Y = []\n",
    "for r in real_Y:\n",
    "    temp_real_Y.extend(r)\n",
    "    \n",
    "temp_predict_Y = []\n",
    "for r in predict_Y:\n",
    "    temp_predict_Y.extend(r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(14464576, 14464576)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: the gold and predicted lists should have the same length.\n",
    "len(temp_real_Y), len(temp_predict_Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       OTHER    0.98377   0.98598   0.98487   4877783\n",
      "         PAD    1.00000   1.00000   1.00000   3235654\n",
      "           X    0.99427   0.98744   0.99084   3728922\n",
      "       event    0.87362   0.89726   0.88528     38359\n",
      "         law    0.95097   0.88134   0.91483     33011\n",
      "    location    0.95822   0.98902   0.97338   1508352\n",
      "organization    0.89477   0.87946   0.88705    231711\n",
      "      person    0.97205   0.91806   0.94428    483093\n",
      "    quantity    0.94198   0.92804   0.93495    133054\n",
      "        time    0.93342   0.93521   0.93431    194637\n",
      "\n",
      "    accuracy                        0.98414  14464576\n",
      "   macro avg    0.95031   0.94018   0.94498  14464576\n",
      "weighted avg    0.98420   0.98414   0.98411  14464576\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "# Per-tag precision / recall / F1 over the flattened lists, printed with five decimals.\n",
    "print(classification_report(temp_real_Y, temp_predict_Y, digits = 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Placeholder',\n",
       " 'Placeholder_1',\n",
       " 'dense/kernel',\n",
       " 'dense/bias',\n",
       " 'transitions',\n",
       " 'logits']"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the comma-separated list of node names handed to freeze_graph() as output nodes.\n",
    "# A node is kept when its op type contains 'Variable' or its name contains 'Placeholder',\n",
    "# 'logits', 'alphas' or 'self/Softmax', unless its name matches one of the excluded substrings below.\n",
    "strings = ','.join(\n",
    "    [\n",
    "        n.name\n",
    "        for n in tf.get_default_graph().as_graph_def().node\n",
    "        if ('Variable' in n.op\n",
    "        or 'Placeholder' in n.name\n",
    "        or 'logits' in n.name\n",
    "        or 'alphas' in n.name\n",
    "        or 'self/Softmax' in n.name)\n",
    "        and 'adam' not in n.name\n",
    "        and 'beta' not in n.name\n",
    "        and 'global_step' not in n.name\n",
    "        and 'ReadVariableOp' not in n.name\n",
    "        and 'AssignVariableOp' not in n.name\n",
    "        and '/Assign' not in n.name\n",
    "        and '/Adam' not in n.name\n",
    "    ]\n",
    ")\n",
    "strings.split(',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_graph(model_dir, output_node_names):\n",
    "    \"\"\"\n",
    "    Freeze the checkpoint recorded in `model_dir` into a 'frozen_model.pb' file.\n",
    "\n",
    "    The checkpoint's meta graph is imported into a fresh graph, its variables are\n",
    "    restored and converted to constants, and the resulting GraphDef is written\n",
    "    into the directory of the checkpoint.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model_dir: str\n",
    "        Directory containing the checkpoint.\n",
    "    output_node_names: str\n",
    "        Comma-separated names of the nodes to keep as outputs.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    AssertionError\n",
    "        If `model_dir` does not exist or no checkpoint can be found in it.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        raise AssertionError(\n",
    "            \"Export directory doesn't exists. Please specify an export \"\n",
    "            'directory: %s' % model_dir\n",
    "        )\n",
    "\n",
    "    checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
    "    # get_checkpoint_state returns None when the directory holds no checkpoint state;\n",
    "    # fail with a clear message instead of an AttributeError on the next line.\n",
    "    if checkpoint is None or not checkpoint.model_checkpoint_path:\n",
    "        raise AssertionError('No checkpoint found in directory: %s' % model_dir)\n",
    "    input_checkpoint = checkpoint.model_checkpoint_path\n",
    "\n",
    "    # dirname/join instead of splitting on '/': a checkpoint path without a directory part\n",
    "    # now resolves to 'frozen_model.pb' rather than '/frozen_model.pb'.\n",
    "    output_graph = os.path.join(os.path.dirname(input_checkpoint), 'frozen_model.pb')\n",
    "    with tf.Session(graph = tf.Graph()) as sess:\n",
    "        saver = tf.train.import_meta_graph(\n",
    "            input_checkpoint + '.meta', clear_devices = True\n",
    "        )\n",
    "        saver.restore(sess, input_checkpoint)\n",
    "        output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            tf.get_default_graph().as_graph_def(),\n",
    "            output_node_names.split(','),\n",
    "        )\n",
    "        with tf.gfile.GFile(output_graph, 'wb') as f:\n",
    "            f.write(output_graph_def.SerializeToString())\n",
    "        print('%d ops in the final graph.' % len(output_graph_def.node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from fastformer-base-entities/model.ckpt\n",
      "WARNING:tensorflow:From <ipython-input-28-9a7215a4e58a>:23: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.convert_variables_to_constants`\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/framework/graph_util_impl.py:277: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
      "INFO:tensorflow:Froze 163 variables.\n",
      "INFO:tensorflow:Converted 163 variables to const ops.\n",
      "7024 ops in the final graph.\n"
     ]
    }
   ],
   "source": [
    "freeze_graph('fastformer-base-entities', strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_graph(frozen_graph_filename):\n",
    "    with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:\n",
    "        graph_def = tf.GraphDef()\n",
    "        graph_def.ParseFromString(f.read())\n",
    "    with tf.Graph().as_default() as graph:\n",
    "        tf.import_graph_def(graph_def)\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = load_graph('fastformer-base-entities/frozen_model.pb')\n",
    "x = g.get_tensor_by_name('import/Placeholder:0')\n",
    "logits = g.get_tensor_by_name('import/logits:0')\n",
    "test_sess = tf.InteractiveSession(graph = g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('Kuala', 'location'), ('Lumpur', 'location'), ('Sempena', 'location'), ('sambutan', 'OTHER'), ('Aidilfitri', 'event'), ('minggu', 'time'), ('depan', 'time'), ('Perdana', 'person'), ('Menteri', 'person'), ('Tun', 'person'), ('Dr', 'person'), ('Mahathir', 'person'), ('Mohamad', 'person'), ('dan', 'OTHER'), ('Menteri', 'person'), ('Pengangkutan', 'person'), ('Anthony', 'person'), ('Loke', 'person'), ('Siew', 'person'), ('Fook', 'person'), ('menitipkan', 'X'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing', 'OTHER'), ('-', 'X'), ('masing', 'X'), ('Dalam', 'OTHER'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'OTHER'), ('Jabatan', 'organization'), ('Keselamatan', 'organization'), ('Jalan', 'location'), ('Raya', 'location'), ('(', 'location'), ('Jkjr', 'X'), (')', 'X'), ('itu', 'OTHER'), ('Dr', 'person'), ('Mahathir', 'person'), ('menasihati', 'OTHER'), ('mereka', 'OTHER'), ('supaya', 'OTHER'), ('berhenti', 'OTHER'), ('berehat', 'OTHER'), ('dan', 'OTHER'), ('tidur', 'OTHER'), ('sebentar', 'OTHER'), ('sekiranya', 'OTHER'), ('mengantuk', 'OTHER'), ('ketika', 'OTHER'), ('memandu', 'OTHER')]\n",
      "CPU times: user 2.15 s, sys: 115 ms, total: 2.27 s\n",
      "Wall time: 2.07 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "predicted = test_sess.run(logits,\n",
    "            feed_dict = {\n",
    "                x: [parsed_sequence],\n",
    "            },\n",
    "    )[0]\n",
    "merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])\n",
    "print(list(zip(merged[0], merged[1])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.tools.graph_transforms import TransformGraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-34-c64b87be774d>:19: FastGFile.__init__ (from tensorflow.python.platform.gfile) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.gfile.GFile.\n"
     ]
    }
   ],
   "source": [
    "transforms = ['add_default_attributes',\n",
    "             'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',\n",
    "             'fold_batch_norms',\n",
    "             'fold_old_batch_norms',\n",
    "             'quantize_weights(fallback_min=-10, fallback_max=10)',\n",
    "             'strip_unused_nodes',\n",
    "             'sort_by_execution_order']\n",
    "\n",
    "input_nodes = [\n",
    "    'Placeholder',\n",
    "]\n",
    "output_nodes = [\n",
    "    'logits',\n",
    "]\n",
    "\n",
    "pb = 'fastformer-base-entities/frozen_model.pb'\n",
    "\n",
    "input_graph_def = tf.GraphDef()\n",
    "with tf.gfile.FastGFile(pb, 'rb') as f:\n",
    "    input_graph_def.ParseFromString(f.read())\n",
    "\n",
    "transformed_graph_def = TransformGraph(input_graph_def, \n",
    "                                           input_nodes,\n",
    "                                           output_nodes, transforms)\n",
    "    \n",
    "with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:\n",
    "    f.write(transformed_graph_def.SerializeToString())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = load_graph('fastformer-tiny-entities/frozen_model.pb.quantized')\n",
    "x = g.get_tensor_by_name('import/Placeholder:0')\n",
    "logits = g.get_tensor_by_name('import/logits:0')\n",
    "test_sess = tf.InteractiveSession(graph = g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('Kuala', 'location'), ('Lumpur', 'location'), ('Sempena', 'location'), ('sambutan', 'OTHER'), ('Aidilfitri', 'event'), ('minggu', 'time'), ('depan', 'time'), ('Perdana', 'person'), ('Menteri', 'person'), ('Tun', 'person'), ('Dr', 'person'), ('Mahathir', 'person'), ('Mohamad', 'person'), ('dan', 'OTHER'), ('Menteri', 'person'), ('Pengangkutan', 'person'), ('Anthony', 'person'), ('Loke', 'person'), ('Siew', 'event'), ('Fook', 'event'), ('menitipkan', 'person'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing', 'OTHER'), ('-', 'X'), ('masing', 'X'), ('Dalam', 'OTHER'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'organization'), ('Jabatan', 'organization'), ('Keselamatan', 'organization'), ('Jalan', 'organization'), ('Raya', 'event'), ('(', 'organization'), ('Jkjr', 'X'), (')', 'X'), ('itu', 'OTHER'), ('Dr', 'person'), ('Mahathir', 'person'), ('menasihati', 'OTHER'), ('mereka', 'OTHER'), ('supaya', 'OTHER'), ('berhenti', 'OTHER'), ('berehat', 'time'), ('dan', 'OTHER'), ('tidur', 'OTHER'), ('sebentar', 'OTHER'), ('sekiranya', 'OTHER'), ('mengantuk', 'X'), ('ketika', 'OTHER'), ('memandu', 'OTHER')]\n",
      "CPU times: user 1.03 s, sys: 151 ms, total: 1.18 s\n",
      "Wall time: 904 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "predicted = test_sess.run(logits,\n",
    "            feed_dict = {\n",
    "                x: [parsed_sequence],\n",
    "            },\n",
    "    )[0]\n",
    "merged = merge_wordpiece_tokens_tagging(bert_sequence, [idx2tag[d] for d in predicted])\n",
    "print(list(zip(merged[0], merged[1])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<b2sdk.file_version.FileVersionInfo at 0x7f528c1d3fd0>"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "file = 'fastformer-base-entities/frozen_model.pb'\n",
    "outPutname = 'entity/fastformer/model.pb'\n",
    "b2_bucket.upload_local_file(\n",
    "    local_file=file,\n",
    "    file_name=outPutname,\n",
    "    file_infos=file_info,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<b2sdk.file_version.FileVersionInfo at 0x7f5104131be0>"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "file = 'fastformer-base-entities/frozen_model.pb.quantized'\n",
    "outPutname = 'entity/fastformer-quantized/model.pb'\n",
    "b2_bucket.upload_local_file(\n",
    "    local_file=file,\n",
    "    file_name=outPutname,\n",
    "    file_infos=file_info,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
